#!/usr/bin/env python2
# -*- coding:utf-8 -*-


import os
import re
import json
import datetime
import traceback
import hashlib
import functools
from contextlib import contextmanager
import collections

from flask import session as flask_session
from sqlalchemy import orm
from sqlalchemy import exc as sql_exc


from Ump.utils import get_tid, compile_str

from Ump.common.utils import inspect_func, package_error
from Ump.common import log
from Ump.common import exception

from Ump.objs.db import models

LOG = log.init_info_logger()

class SessionWrapper(object):
    '''Query helpers over the project's SQLAlchemy models.

    All lookups go through ``model.query``. Unless told otherwise they only
    return rows whose ``deleted`` column is False. None of the public
    methods mutate the spec dicts passed in by callers.
    '''

    def __init__(self, session=None):
        # ``session`` is accepted for backward compatibility but is unused;
        # every query goes through the model's ``query`` attribute instead.
        pass

    @classmethod
    def _get_model(cls, model=None):
        '''Return *model*, refusing a missing one.

        :raises AssertionError: when *model* is None. The check is raised
            explicitly rather than via ``assert`` so it survives ``python -O``.
        '''
        if model is None:
            raise AssertionError('a model class is required')
        return model

    def _rebuild_query(self, query, skip=0, limit=None, order=None, desc=False):
        '''Apply ordering and pagination to *query* and return the new query.

        :param skip: rows to skip; honoured even when *limit* is not given.
        :param limit: maximum number of rows; a falsy value means unlimited.
        :param order: column to sort by (falsy for no explicit ordering).
        :param desc: sort *order* descending when truthy.
        '''
        if order:
            query = query.order_by(order.desc() if desc else order)
        if skip:
            query = query.offset(skip)
        if limit:
            query = query.limit(limit)
        return query

    @inspect_func
    def get_one(self, model=None, id_or_spec=None, order=None, desc=None):
        '''Return the first non-deleted row of *model* matching *id_or_spec*.

        :param id_or_spec: a primary-key value, or a dict of column filters.
            A ``username`` key in the dict is translated to ``user_id``.
        :param order: optional column to order by before taking the first row.
        :param desc: order descending when truthy.
        :returns: the model instance, or None when nothing matches.
        '''
        model = self._get_model(model)
        if isinstance(id_or_spec, dict):
            # Work on a copy so the caller's dict is not modified.
            spec = self._username_to_id(dict(id_or_spec))
        else:
            spec = {'id': id_or_spec}
        spec['deleted'] = False

        query = model.query.filter_by(**spec)
        # Apply the requested ordering so first() picks the intended row.
        query = self._rebuild_query(query, order=order, desc=desc)
        return query.first()

    @inspect_func
    def get_list(self, model=None, spec={}, skip=0, limit=None, order=None, desc=False):
        '''Return all non-deleted rows of *model* matching *spec*.

        A ``username`` key in *spec* is translated to ``user_id``. The rest
        of the work (filtering, default ordering, pagination) is delegated
        to :meth:`get_list_2`.
        '''
        model = self._get_model(model)
        if isinstance(spec, dict):
            # Copy: *spec* may be the shared default or a caller-owned dict.
            spec = self._username_to_id(dict(spec))
        return self.get_list_2(model, spec=spec, skip=skip, limit=limit,
                               order=order, desc=desc)

    def get_list_by_ids(self, model, ids):
        '''Fetch one row of *model* per id, preserving the order of *ids*.

        :param ids: an iterable of ids, or a comma-separated string of ids.
            Empty entries are skipped.
        :raises exception.InvalidParameter: if *ids* is not iterable or any
            id has no matching (non-deleted) row.
        '''
        if not ids:
            return []

        if isinstance(ids, (str, unicode)):
            ids = ids.split(',')

        if not isinstance(ids, collections.Iterable):
            raise exception.InvalidParameter(ids=ids)

        res = []
        for _id in ids:
            if not _id:
                continue
            mo = self.get_one(model=model, id_or_spec=_id)
            if not mo:
                raise exception.InvalidParameter(model=str(model), id=_id)
            res.append(mo)
        return res

    def get_list_2(self, model, spec={}, skip=0, limit=None, order=None, desc=False, read_deleted=False):
        '''Return rows of *model* filtered by *spec*, ordered and paginated.

        :param spec: dict of column filters. The control key
            ``read_deleted='all'`` includes soft-deleted rows. The
            ``read_deleted`` key is always removed before filtering because
            it is not a column. Non-dict values apply no filter at all.
        :param read_deleted: when truthy, do not add the ``deleted=False``
            filter.
        :param order: column to sort by. When omitted and the model has an
            ``atime`` column, rows are sorted by ``atime`` descending.
        '''
        query = model.query
        if isinstance(spec, dict):
            spec = dict(spec)
            if not read_deleted:
                spec['deleted'] = False
            if spec.pop('read_deleted', None) == 'all':
                spec.pop('deleted', None)
            query = query.filter_by(**spec)
        if not order and hasattr(model, 'atime'):
            order = model.atime
            desc = True
        query = self._rebuild_query(query, skip, limit, order, desc)

        LOG.info('--- model %s spec %s skip %s limit %s order %s desc %s' % (model, spec, skip, limit, order, desc))
        return query.all()

    def exists(self, model=None, spec={}):
        '''Return True if at least one non-deleted row matches *spec*.'''
        return bool(self.get_list(model=model, spec=spec))

    def get_clusterid(self):
        '''Return the id of the first cluster.

        :raises Exception: when no cluster exists. The message asks the user
            to add a host before creating a target.
        '''
        clusters = self.get_list(models.Cluster)
        if not clusters:
            raise Exception('请添加主机后再进行创建目标操作')
        return clusters[0].id

    def get_oplog_export_path(self, name = None):
        '''Return ``value_setting`` of the SysconfigForUMP row named *name*.

        Raises AttributeError if no such row exists.
        '''
        res = self.get_one(models.SysconfigForUMP, id_or_spec={'name': name})
        return res.value_setting

    def get_schedule_job(self, name = None):
        '''Return the ScheduleJob whose ``class_name`` is *name*, or None.'''
        return self.get_one(models.ScheduleJob, id_or_spec={'class_name': name})

    def _username_to_id(self, spec):
        '''Replace a ``username`` key in *spec* by ``user_id``, in place.

        Returns *spec*. Non-dict values are returned untouched.
        NOTE(review): when the username is unknown the key is dropped without
        adding a ``user_id`` filter, which widens the query. Confirm callers
        expect that.
        '''
        if isinstance(spec, dict):
            username = spec.pop('username', '')
            if username:
                user = self.get_user(username)
                if user:
                    spec['user_id'] = user.id
        return spec

    def get_cluster(self, cluster_id=None):
        '''Return the cluster *cluster_id* (default: the first cluster).

        :raises exception.ClusterNotFound: if no such cluster exists.
        '''
        if not cluster_id:
            cluster_id = self.get_clusterid()
        cluster = self.get_one(models.Cluster, id_or_spec=cluster_id)
        if not cluster:
            raise exception.ClusterNotFound(cluster_id=cluster_id)
        return cluster

    def check_protocol(self, protocol='iscsi'):
        '''Return True if *protocol* is 'iscsi' or 'nbd'.

        :raises exception.ProtocolNotSupported: for any other protocol.
        '''
        if protocol not in ['iscsi', 'nbd']:
            raise exception.ProtocolNotSupported(protocol)
        return True

    def get_user(self, username='admin'):
        '''Return the non-deleted User named *username*, or None.'''
        return self.get_one(model=models.User, id_or_spec={'name': username})

    def user_login(self, name, password):
        '''Return the User matching *name* and *password*, or None.

        The password is compared by its MD5 hex digest. A ``unicode``
        password is encoded as UTF-8 first so non-ASCII input does not fail
        the implicit ASCII encode.
        NOTE(review): unsalted MD5 is weak for password storage. Changing it
        requires migrating the stored hashes.
        '''
        if isinstance(password, unicode):
            password = password.encode('utf-8')
        pwsh = hashlib.md5(password).hexdigest()
        return self.get_one(models.User, id_or_spec={'name': name, 'password': pwsh})

    def get_user_id(self, username):
        '''Return the id of user *username*.

        :raises exception.UserNotFound: if the user does not exist.
        '''
        user = self.get_user(username=username)
        if not user:
            raise exception.UserNotFound(username)
        return user.id

    def get_pool(self, params):
        '''Look up a Pool by ``params['id']`` or by a path object.

        *params* is either a dict containing ``id``, a dict containing
        ``path``, or the path object itself. The path must expose
        ``username``, ``protocol`` and ``pool_name``.

        :raises exception.UserNotFound: if the path's user does not exist.
        '''
        if isinstance(params, dict) and 'id' in params:
            spec = params['id']
        else:
            path = params['path'] if isinstance(params, dict) else params
            user = self.get_user(username=path.username)
            if not user:
                raise exception.UserNotFound(path.username)
            spec = {
                'protocol': path.protocol,
                'user_id': user.id,
                'name': path.pool_name,
            }
        return self.get_one(models.Pool, id_or_spec=spec)

    def get_volume(self, params):
        '''Look up a Volume by ``params['id']`` or by its pool path.

        The path must also expose ``vol_name``.

        :raises exception.PoolNotFound: if the volume's pool does not exist.
        '''
        if isinstance(params, dict) and 'id' in params:
            spec = params['id']
        else:
            path = params['path'] if isinstance(params, dict) else params
            pool = self.get_pool(params)
            if not pool:
                raise exception.PoolNotFound(params)
            spec = {
                'pool_id': pool.id,
                'name': path.vol_name,
            }
        return self.get_one(models.Volume, id_or_spec=spec)

    def get_snapshot(self, params):
        '''Look up a Snapshot by ``params['id']`` or by its volume path.

        The path must also expose ``snap_name``.

        :raises exception.VolumeNotFound: if the snapshot's volume does not
            exist.
        '''
        if isinstance(params, dict) and 'id' in params:
            spec = params['id']
        else:
            path = params['path'] if isinstance(params, dict) else params
            volume = self.get_volume(params)
            if not volume:
                raise exception.VolumeNotFound(params)
            spec = {
                'volume_id': volume.id,
                'name': path.snap_name,
            }
        return self.get_one(models.Snapshot, id_or_spec=spec)

    def get_vgroup(self, params):
        '''Return the VGroup named ``params['vgroup_name']`` owned by
        ``params['username']``, or None.

        :raises exception.UserNotFound: if the user does not exist.
        '''
        user_id = self.get_user_id(params.get('username'))
        spec = {
            'name': params.get('vgroup_name'),
            'user_id': user_id,
        }
        return self.get_one(models.VGroup, id_or_spec=spec)

    def get_cgsnapshot(self, params):
        '''Return the CGSnapshot ``params['cgsnapshot_name']`` of the volume
        group described by *params*.

        Returns None when the volume group or the snapshot does not exist.

        :raises exception.UserNotFound: if the user does not exist.
        '''
        user_id = self.get_user_id(params.get('username'))
        vgroup = self.get_vgroup(params)
        if not vgroup:
            # Without the group there can be no snapshot of it.
            return None
        spec = {
            'name': params.get('cgsnapshot_name'),
            'vgroup_id': vgroup.id,
            'user_id': user_id,
        }
        return self.get_one(models.CGSnapshot, id_or_spec=spec)

    def get_folder(self, params):
        '''Return the Folder ``params['folder_name']`` owned by
        ``params['username']``, or None.

        :raises exception.UserNotFound: if the user does not exist.
        '''
        user_id = self.get_user_id(params.get('username'))
        spec = {
            'name': params.get('folder_name'),
            'user_id': user_id,
        }
        return self.get_one(models.Folder, id_or_spec=spec)

    def get_qos(self, params):
        '''Return the QOS row named ``params['name']``, or None.'''
        return self.get_one(models.QOS, id_or_spec={'name': params.get('name')})

    def check_pool_quota(self, pool, new_size, old_size=None):
        '''Validate the sizes for a pool quota check; currently always True.

        *new_size* and *old_size* are converted with ``int()``, so
        non-numeric values raise ValueError/TypeError.
        TODO: the free-space comparison against *pool* is disabled.
        '''
        new_size = int(new_size)
        old_size = int(old_size) if old_size else 0
        return True

    def db_cluster(self, id_or_spec):
        # NOTE(review): ignores *id_or_spec* and does not filter deleted rows.
        # It always returns the first cluster.
        return models.Cluster.query.first()

    def db_clusters(self, spec={}):
        return self.get_list(model=models.Cluster, spec=spec)

    def db_host(self, id_or_spec):
        return self.get_one(model=models.Host, id_or_spec=id_or_spec)

    def db_hosts(self, spec={}):
        return self.get_list(model=models.Host, spec=spec)

    def db_disk(self, id_or_spec):
        return self.get_one(model=models.Disk, id_or_spec=id_or_spec)

    def db_disks(self, spec={}):
        return self.get_list(model=models.Disk, spec=spec)

    def db_user(self, id_or_spec):
        return self.get_one(model=models.User, id_or_spec=id_or_spec)

    def db_users(self, spec={}):
        return self.get_list(model=models.User, spec=spec)

    def db_pool(self, id_or_spec):
        return self.get_one(model=models.Pool, id_or_spec=id_or_spec)

    def db_pools(self, spec={}):
        return self.get_list(model=models.Pool, spec=spec)

    def db_volume(self, id_or_spec):
        return self.get_one(model=models.Volume, id_or_spec=id_or_spec)

    def db_volumes(self, spec={}):
        return self.get_list(model=models.Volume, spec=spec)

    def db_snapshot(self, id_or_spec):
        return self.get_one(model=models.Snapshot, id_or_spec=id_or_spec)

    def db_snapshots(self, spec={}):
        return self.get_list(model=models.Snapshot, spec=spec)

    def db_vgroups(self, spec={}):
        return self.get_list(model=models.VGroup, spec=spec)

    def db_cgsnapshots(self, spec={}):
        return self.get_list(model=models.CGSnapshot, spec=spec)

    def db_oplog(self, spec={}):
        return self.get_list(model=models.Oplog, spec=spec)

# Module-level shared helper instance; ``session`` is ignored by the class.
_sw = SessionWrapper(session=None)


@contextmanager
def op_logger(resource='noresource', event='add', **kw):
    '''Context manager that records an operation-log entry for the block.

    Yields an ``OplogManager`` created with status ``ST_DONE`` plus an extra
    ``ok`` attribute (initially True).

    - If the block raises, the logger is switched to ``ST_FAILED`` with the
      exception as ``error_msg``, the error is logged, and it is re-raised.
    - On exit, ``logger.create(**kw)`` persists the entry unless the block
      set ``logger.ok = False``.
    - If persisting fails while the block's own exception is already
      propagating, the persistence error is printed and suppressed, so the
      caller still sees the original exception.
    '''
    # Local import, presumably to avoid an import cycle with the oplog
    # package. TODO confirm.
    from Ump.objs.oplog.manager import OplogManager, ST_DONE, ST_FAILED
    logger = OplogManager(resource=resource, event=event, status=ST_DONE)
    logger.ok = True
    failed = False
    try:
        yield logger
    except Exception as e:
        failed = True
        logger.update_props(status=ST_FAILED, error_msg=e)
        LOG.info('op_logger e %s' % e)
        raise
    finally:
        if logger.ok:
            if failed:
                # An error raised here would replace the block's exception
                # (Python 2 has no exception chaining), so report it instead.
                try:
                    logger.create(**kw)
                except Exception:
                    traceback.print_exc()
            else:
                logger.create(**kw)

def enable_log_and_session(resource, event, **kw):
    '''Decorator factory: run a method inside an ``op_logger`` context.

    The decorated method is called as ``func(self, logger, *args, **kwargs)``,
    where ``logger`` is the object yielded by ``op_logger``. When the Flask
    session holds a ``token``, its ``user_id`` is recorded on the logger.

    :param resource: oplog resource name passed to ``op_logger``.
    :param event: oplog event name passed to ``op_logger``.
    :param kw: ``disable_oplog=True`` suppresses persisting the oplog entry.
    '''
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            with op_logger(resource=resource, event=event, oplog_obj=None) as _logger:
                if kw.get('disable_oplog', False):
                    _logger.ok = False

                try:
                    token = flask_session.get('token')
                except Exception:
                    # Reading the session failed (e.g. no request context);
                    # continue without a user id.
                    token = None
                    traceback.print_exc()

                if token:
                    _logger.update_props(user_id=token.get('user_id'))

                try:
                    return func(self, _logger, *args, **kwargs)
                except Exception:
                    traceback.print_exc()
                    # Bare ``raise`` keeps the original traceback; ``raise e``
                    # would restart it here under Python 2.
                    raise
        return wrapper
    return decorator


def check_task_result(rsp):
    '''Return the stored error detail for a failed task response.

    :param rsp: object exposing ``task_uuid`` and ``success``.
    :returns: the ``detail`` of the Task row with ``uuid == rsp.task_uuid``
        when ``rsp.success`` is falsy and such a row exists; otherwise None.
    '''
    uuid = rsp.task_uuid
    if rsp.success:
        return None
    task = models.Task.query.filter_by(uuid=uuid).first()
    return task.detail if task else None

def get_token_from_flask_session():
    '''Return the ``token`` entry of the Flask session, or None.

    Any exception raised while reading the session (for example when there
    is no active request context) is swallowed and reported as None.
    '''
    try:
        return flask_session.get('token')
    except Exception:
        return None

def parse_error_string(error):
    '''Split a shell-command failure message into ``(cmd, recode, stderr)``.

    Each part is extracted by ``compile_str(pattern, error)``, which
    presumably returns the first capture group (TODO confirm against
    ``Ump.utils.compile_str``):

    - cmd: text between "failed to execute shell command:" and "return code"
    - recode: digits after "return code:" and before "stdout"
    - stderr: everything after "stderr:" and one following character

    Patterns are raw strings so ``\\s``/``\\S``/``\\d`` are regex escapes
    rather than invalid string escapes.
    '''
    start_str = "failed to execute shell command:"
    recode_str = 'return code'
    cmd_regexp = r'%s([\s\S]*)%s' % (start_str, recode_str)
    cmd = compile_str(cmd_regexp, error)

    recode_regexp = r'%s:[\s]*([\d]*)[\s]*%s' % (recode_str, 'stdout')
    recode = compile_str(recode_regexp, error)

    stderr_rxp = r'%s[\s\S]([\s\S]*)' % ('stderr:')
    stderr = compile_str(stderr_rxp, error)
    return cmd, recode, stderr
    

def enable_oplog(resource, event, **kw):
    '''Decorator factory for task-response handlers that records an oplog.

    The decorated method is called as ``func(self, logger, rsp, ...)``, where
    ``rsp`` is the first positional argument. ``rsp.user_id`` (when truthy)
    is recorded on the oplog entry. Exceptions from the handler are printed
    and not re-raised; the re-raise is commented out in the history, so this
    looks intentional. Afterwards, if ``check_task_result(rsp)`` reports an
    error, the error is parsed, passed to ``package_error`` and
    ``exception.TaskExecuteFail`` is raised.

    :param kw: ``disable_oplog=True`` suppresses persisting the oplog entry.
    '''
    skip_oplog = kw.get('disable_oplog', False)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            rsp = args[0]
            with op_logger(resource=resource, event=event, oplog_obj=None) as oplog:
                if skip_oplog:
                    oplog.ok = False

                owner_id = rsp.user_id
                if owner_id:
                    oplog.update_props(user_id=owner_id)

                result = None
                try:
                    result = func(self, oplog, *args, **kwargs)
                except Exception:
                    traceback.print_exc()

                task_error = check_task_result(rsp)
                if task_error:
                    cmd, recode, stderr = parse_error_string(task_error)
                    package_error(recode, stderr, cmd=cmd)
                    raise exception.TaskExecuteFail(task_error)

                return result
        return wrapper
    return decorator


if __name__ == '__main__':
    # No command-line behaviour; this module is meant to be imported.
    pass

